{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "import random\n",
    "import math\n",
    "import sys\n",
    "from PIL import Image\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tfrecord文件已经生成共 500 张 张\n"
     ]
    }
   ],
   "source": [
    "# 验证集数量\n",
    "_NUM_TEST = 500\n",
    "# 随机数种子\n",
    "_RANDOM_SEED = 0\n",
    "# 数据集文件夹\n",
    "_DATA_DIR = \"c:/learn/captcha/images\"\n",
    "# TFrecord文件夹\n",
    "TFRECORD_DIR = \"c:/learn/captcha\"\n",
    "\n",
    "# 获取所有验证码图片\n",
    "def _get_filenames_and_classes(data_dir):\n",
    "    photo_filenames = []\n",
    "    for filename in os.listdir(data_dir):\n",
    "        # 获取文件路径\n",
    "        path = os.path.join(data_dir, filename)\n",
    "        photo_filenames.append(path)\n",
    "    return photo_filenames\n",
    "\n",
    "def int64_frature(values):\n",
    "    \"\"\"Wrap a value, or a tuple/list of values, in an int64 tf.train.Feature.\"\"\"\n",
    "    # A single value is promoted to a one-element list; tuples and lists pass through as-is.\n",
    "    int_values = values if isinstance(values, (tuple, list)) else [values]\n",
    "    int_list = tf.train.Int64List(value=int_values)\n",
    "    return tf.train.Feature(int64_list=int_list)\n",
    "\n",
    "def bytes_feature(values):\n",
    "    \"\"\"Wrap ``values`` as the single element of a bytes tf.train.Feature.\"\"\"\n",
    "    payload = tf.train.BytesList(value=[values])\n",
    "    return tf.train.Feature(bytes_list=payload)\n",
    "\n",
    "def image_to_tfexample(image_data, label0, label1, label2, label3):\n",
    "    \"\"\"Build a tf.train.Example from the image bytes and the four labels.\n",
    "\n",
    "    The image is stored with bytes_feature and each label with\n",
    "    int64_frature, one feature per label.\n",
    "    \"\"\"\n",
    "    feature_map = {\n",
    "        \"image\": bytes_feature(image_data),\n",
    "        \"label0\": int64_frature(label0),\n",
    "        \"label1\": int64_frature(label1),\n",
    "        \"label2\": int64_frature(label2),\n",
    "        \"label3\": int64_frature(label3),\n",
    "    }\n",
    "    features = tf.train.Features(feature=feature_map)\n",
    "    return tf.train.Example(features=features)\n",
    "\n",
    "def _convert_dataset(split_name, filenames, dataset_dir):\n",
    "    \"\"\"Write images and their four digit labels to a TFRecord file.\n",
    "\n",
    "    Every image is resized to 224x224, converted to grayscale (PIL mode L)\n",
    "    and stored as raw bytes. The first four characters of the file's base\n",
    "    name are parsed as integers and stored as the four labels. Files that\n",
    "    raise IOError while being read are reported and skipped.\n",
    "\n",
    "    Args:\n",
    "        split_name: either train or test; the output file is\n",
    "            <split_name>.tfrecords inside the module-level TFRECORD_DIR.\n",
    "        filenames: paths of the image files to convert.\n",
    "        dataset_dir: unused; kept so existing callers keep working.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if split_name is neither train nor test.\n",
    "        ValueError: if one of the first four characters of a base name\n",
    "            cannot be parsed by int().\n",
    "    \"\"\"\n",
    "    assert split_name in [\"train\", \"test\"]\n",
    "\n",
    "    # Path and name of the tfrecord file\n",
    "    output_filename = os.path.join(TFRECORD_DIR, split_name + \".tfrecords\")\n",
    "    total = len(filenames)\n",
    "\n",
    "    # No tf.Session is opened: nothing in this function runs a graph.\n",
    "    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:\n",
    "        for i, filename in enumerate(filenames):\n",
    "            try:\n",
    "                print(\"正在转换第 %d 张图片,共 %d 张\" %(i+1, total),end=\"\\r\")\n",
    "\n",
    "                # Read the image\n",
    "                image = Image.open(filename)\n",
    "                # Image.resize returns a new image, so the result must be kept\n",
    "                image = image.resize((224, 224))\n",
    "                # Convert to grayscale, then to raw bytes\n",
    "                image_data = np.array(image.convert(\"L\")).tobytes()\n",
    "\n",
    "                # The label is the first four characters of the base name;\n",
    "                # os.path.basename copes with the platform's path separators\n",
    "                labels = os.path.basename(filename)[0:4]\n",
    "                num_labels = [int(labels[j]) for j in range(4)]\n",
    "\n",
    "                # Build the protocol buffer record and append it to the file\n",
    "                example = image_to_tfexample(image_data, *num_labels)\n",
    "                tfrecord_writer.write(example.SerializeToString())\n",
    "            except IOError as e:\n",
    "                print(\"Could not read \", filename)\n",
    "                print(\"Error:\", e)\n",
    "                print(\"Skip it\")\n",
    "\n",
    "    # Terminate the \\r-rewritten progress line so later output starts on a fresh line\n",
    "    if total:\n",
    "        print()\n",
    "\n",
    "# Collect every image path\n",
    "photo_filenames = _get_filenames_and_classes(_DATA_DIR)\n",
    "\n",
    "# Seed the RNG, shuffle, then split: the first _NUM_TEST files form the\n",
    "# test set and the remainder the training set\n",
    "random.seed(_RANDOM_SEED)\n",
    "random.shuffle(photo_filenames)\n",
    "testing_filenames = photo_filenames[:_NUM_TEST]\n",
    "trainning_filenames = photo_filenames[_NUM_TEST:]\n",
    "\n",
    "# Convert both splits to TFRecord files\n",
    "_convert_dataset(\"train\", trainning_filenames, _DATA_DIR)\n",
    "_convert_dataset(\"test\", testing_filenames, _DATA_DIR)\n",
    "\n",
    "print(\"tfrecord文件已经生成\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
